{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 第11课：动手实战基于 LSTM 轻松生成各种古诗"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "puncs = [']', '[', '（', '）', '{', '}', '：', '《', '》']\n",
    "def preprocess_file(Config):\n",
    "    '''Read the poetry corpus and build the character vocabulary.\n",
    "\n",
    "    Every line of ``Config.poetry_file`` is treated as one poem: the characters\n",
    "    listed in ``puncs`` are removed, surrounding whitespace is stripped and a\n",
    "    ']' is appended as the end-of-poem marker.\n",
    "\n",
    "    Args:\n",
    "        Config: object exposing ``poetry_file``, the path of a UTF-8 text file.\n",
    "\n",
    "    Returns:\n",
    "        A 4-tuple ``(word2numF, num2word, words, files_content)``:\n",
    "          * word2numF: callable mapping a character to its id; ids start at 1\n",
    "            and 0 is returned for any character outside the vocabulary.\n",
    "          * num2word: dict mapping the 0-based position in ``words`` to the\n",
    "            character. NOTE: this is offset by one from ``word2numF``, i.e.\n",
    "            ``num2word[word2numF(c) - 1] == c`` for every vocabulary character.\n",
    "          * words: tuple of vocabulary characters, most frequent first (ties in\n",
    "            code-point order); empty when no character occurs more than twice.\n",
    "          * files_content: the cleaned corpus as a single string.\n",
    "    '''\n",
    "    from collections import Counter\n",
    "\n",
    "    # Collect the cleaned poems and join once instead of growing a string in a loop.\n",
    "    poems = []\n",
    "    with open(Config.poetry_file, 'r', encoding='utf-8') as f:\n",
    "        for line in f:\n",
    "            for char in puncs:\n",
    "                line = line.replace(char, '')\n",
    "            # A trailing ']' marks the end of each poem.\n",
    "            poems.append(line.strip() + ']')\n",
    "    files_content = ''.join(poems)\n",
    "\n",
    "    counted_words = Counter(files_content)\n",
    "    # ']' is only a separator and never part of the vocabulary. pop() with a\n",
    "    # default keeps this safe for empty or very small corpora.\n",
    "    counted_words.pop(']', None)\n",
    "\n",
    "    # Drop rare characters (seen at most twice), then order by descending count.\n",
    "    # The explicit code-point tie-break keeps the ids deterministic.\n",
    "    wordPairs = sorted(\n",
    "        ((word, count) for word, count in counted_words.items() if count > 2),\n",
    "        key=lambda pair: (-pair[1], pair[0])\n",
    "    )\n",
    "    words = tuple(word for word, _ in wordPairs)\n",
    "\n",
    "    # Character <-> id mappings (see the offset note in the docstring).\n",
    "    word2num = dict((c, i + 1) for i, c in enumerate(words))\n",
    "    num2word = dict(enumerate(words))\n",
    "    word2numF = lambda x: word2num.get(x, 0)\n",
    "    return word2numF, num2word, words, files_content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Config(object):\n",
    "    '''Paths and hyper-parameters shared by preprocessing and the model.'''\n",
    "    # Context window length: six characters are used to predict the seventh.\n",
    "    max_len = 6\n",
    "    # Passed to fit_generator as steps_per_epoch by PoetryModel.train().\n",
    "    batch_size = 512\n",
    "    # Learning rate handed to the Adam optimizer.\n",
    "    learning_rate = 1e-3\n",
    "    # Corpus to read, and the file the trained model is saved to / loaded from.\n",
    "    poetry_file = '../data/11/poetry.txt'\n",
    "    weight_file = '../data/11/poetry_model.h5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "\n",
    "import keras\n",
    "import numpy as np\n",
    "from keras.models import load_model\n",
    "from keras import Model\n",
    "from keras.callbacks import LambdaCallback\n",
    "from keras.layers import *\n",
    "from keras.optimizers import Adam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PoetryModel(object):\n",
    "    '''Character-level poetry generator built on stacked bidirectional GRU layers.\n",
    "\n",
    "    A window of ``config.max_len`` characters is used to predict the next one.\n",
    "    Class ids follow ``word2numF``: id 0 stands for an unknown / rare character\n",
    "    and ids 1..len(words) are the vocabulary, so a sampled id ``k`` decodes to\n",
    "    ``num2word[k - 1]`` (``num2word`` is keyed by 0-based vocabulary position).\n",
    "    '''\n",
    "\n",
    "    def __init__(self, config):\n",
    "        '''Preprocess the corpus, then load the saved model or train a new one.\n",
    "\n",
    "        Args:\n",
    "            config: object exposing poetry_file, weight_file, max_len,\n",
    "                batch_size and learning_rate.\n",
    "        '''\n",
    "        self.model = None\n",
    "        self.do_train = True\n",
    "        self.loaded_model = False\n",
    "        self.config = config\n",
    "\n",
    "        # Corpus preprocessing\n",
    "        self.word2numF, self.num2word, self.words, self.files_content = preprocess_file(self.config)\n",
    "        if os.path.exists(self.config.weight_file):\n",
    "            self.model = load_model(self.config.weight_file)\n",
    "            self.model.summary()\n",
    "        else:\n",
    "            self.train()\n",
    "        self.do_train = False\n",
    "        self.loaded_model = True\n",
    "\n",
    "    def build_model(self):\n",
    "        '''Build and compile the network and store it in ``self.model``.'''\n",
    "        # One class per vocabulary character plus id 0 for unknown characters,\n",
    "        # matching the ids produced by word2numF for both inputs and targets.\n",
    "        vocab_size = len(self.words) + 1\n",
    "        input_tensor = Input(shape=(self.config.max_len,))\n",
    "        embedd = Embedding(vocab_size, 300, input_length=self.config.max_len)(input_tensor)\n",
    "        # Each layer consumes the previous layer's output, so both GRU layers\n",
    "        # and both dropout layers are part of the graph.\n",
    "        hidden = Bidirectional(GRU(128, return_sequences=True))(embedd)\n",
    "        hidden = Dropout(0.6)(hidden)\n",
    "        hidden = Bidirectional(GRU(128, return_sequences=True))(hidden)\n",
    "        hidden = Dropout(0.6)(hidden)\n",
    "        flatten = Flatten()(hidden)\n",
    "        dense = Dense(vocab_size, activation='softmax')(flatten)\n",
    "        self.model = Model(inputs=input_tensor, outputs=dense)\n",
    "        optimizer = Adam(lr=self.config.learning_rate)\n",
    "        self.model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])\n",
    "\n",
    "    def sample(self, preds, temperature=1.0):\n",
    "        '''Draw a class id from ``preds`` after temperature scaling.\n",
    "\n",
    "        temperature == 1.0 samples from the predicted distribution unchanged;\n",
    "        values below 1.0 sharpen it (more conservative output) and values\n",
    "        above 1.0 flatten it (more diverse output).\n",
    "\n",
    "        Args:\n",
    "            preds: 1-D sequence of class probabilities.\n",
    "            temperature: positive scaling factor applied to the log-probabilities.\n",
    "\n",
    "        Returns:\n",
    "            The sampled class index.\n",
    "        '''\n",
    "        preds = np.asarray(preds).astype('float64')\n",
    "        # log(0) is -inf, which exp() turns back into probability 0; only the\n",
    "        # divide-by-zero warning is silenced here.\n",
    "        with np.errstate(divide='ignore'):\n",
    "            preds = np.log(preds) / temperature\n",
    "        exp_preds = np.exp(preds)\n",
    "        preds = exp_preds / np.sum(exp_preds)\n",
    "        probas = np.random.multinomial(1, preds, 1)\n",
    "        return np.argmax(probas)\n",
    "\n",
    "    def _index_to_char(self, index):\n",
    "        '''Decode a class id (1-based, as produced by word2numF) into its character.'''\n",
    "        return self.num2word[int(index) - 1]\n",
    "\n",
    "    def _next_char(self, window, temperature):\n",
    "        '''Sample the character following ``window`` (at most max_len characters).'''\n",
    "        x_pred = np.zeros((1, self.config.max_len))\n",
    "        for t, char in enumerate(window):\n",
    "            x_pred[0, t] = self.word2numF(char)\n",
    "        preds = np.array(self.model.predict(x_pred, verbose=0)[0], dtype='float64')\n",
    "        # Class 0 means unknown character and cannot be rendered, so never sample it.\n",
    "        preds[0] = 0.0\n",
    "        return self._index_to_char(self.sample(preds, temperature))\n",
    "\n",
    "    def generate_sample_result(self, epoch, logs):\n",
    "        '''Epoch-end callback: print a 20-character sample for several temperatures.'''\n",
    "        max_len = self.config.max_len\n",
    "        print(\"\\n==================Epoch {}=====================\".format(epoch))\n",
    "        for diversity in [0.5, 1.0, 1.5]:\n",
    "            print(\"------------Diversity {}--------------\".format(diversity))\n",
    "            start_index = random.randint(0, len(self.files_content) - max_len - 1)\n",
    "            sentence = self.files_content[start_index: start_index + max_len]\n",
    "            for _ in range(20):\n",
    "                # Always condition on the last max_len characters.\n",
    "                sentence += self._next_char(sentence[-max_len:], diversity)\n",
    "            print(sentence)\n",
    "\n",
    "    def predict(self, text):\n",
    "        '''Generate one six-character segment per character of ``text``.\n",
    "\n",
    "        Each segment starts with the given character followed by five sampled\n",
    "        characters. Input shorter than four characters is padded with random\n",
    "        non-punctuation vocabulary characters.\n",
    "\n",
    "        Args:\n",
    "            text: seed characters; None is treated as an empty string.\n",
    "\n",
    "        Returns:\n",
    "            The generated string, or None when no model has been loaded yet.\n",
    "        '''\n",
    "        if not self.loaded_model:\n",
    "            return\n",
    "        with open(self.config.poetry_file, 'r', encoding='utf-8') as f:\n",
    "            file_list = f.readlines()\n",
    "        max_len = self.config.max_len\n",
    "\n",
    "        # Pad the seed up to four characters.\n",
    "        text = text or ''\n",
    "        candidates = [w for w in self.words if w not in (',', '。', '，')]\n",
    "        if candidates:\n",
    "            for _ in range(4 - len(text)):\n",
    "                text += random.choice(candidates)\n",
    "\n",
    "        # Prime the context window with the tail of a random corpus line. Only\n",
    "        # the newline is stripped, so a final line without one keeps its last character.\n",
    "        random_line = random.choice(file_list).rstrip('\\n') if file_list else ''\n",
    "        window = random_line[max(0, len(random_line) - (max_len - 1)):]\n",
    "        res = ''\n",
    "        for c in text:\n",
    "            window = (window + c)[-max_len:]\n",
    "            verse = c\n",
    "            for _ in range(5):\n",
    "                next_char = self._next_char(window, 1.0)\n",
    "                window = (window + next_char)[-max_len:]\n",
    "                verse += next_char\n",
    "            res += verse\n",
    "        return res\n",
    "\n",
    "    def data_generator(self):\n",
    "        '''Yield ``(x_vec, y_vec)`` training samples forever, one sample per item.\n",
    "\n",
    "        x_vec has shape (1, max_len) and holds character ids; y_vec has shape\n",
    "        (1, len(words) + 1) and is the one-hot id of the following character.\n",
    "        Windows touching a poem separator or bracket character are skipped, and\n",
    "        the generator wraps around at the end of the corpus.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the corpus contains no usable window at all.\n",
    "        '''\n",
    "        puncs = [']', '[', '（', '）', '{', '}', '：', '《', '》', ':']\n",
    "        content = self.files_content\n",
    "        max_len = self.config.max_len\n",
    "        vocab_size = len(self.words) + 1\n",
    "        # Number of start positions that still have a target character after them.\n",
    "        n_windows = len(content) - max_len\n",
    "        i = 0\n",
    "        produced = False\n",
    "        while 1:\n",
    "            if i >= n_windows:\n",
    "                if not produced:\n",
    "                    raise ValueError('files_content holds no usable training window of max_len + 1 characters')\n",
    "                i, produced = 0, False\n",
    "            x = content[i: i + max_len]\n",
    "            y = content[i + max_len]\n",
    "            i += 1\n",
    "            if any(p in x or p in y for p in puncs):\n",
    "                continue\n",
    "            y_vec = np.zeros(shape=(1, vocab_size), dtype=bool)\n",
    "            y_vec[0, self.word2numF(y)] = True\n",
    "            x_vec = np.zeros(shape=(1, max_len), dtype=np.int32)\n",
    "            for t, char in enumerate(x):\n",
    "                x_vec[0, t] = self.word2numF(char)\n",
    "            produced = True\n",
    "            yield x_vec, y_vec\n",
    "\n",
    "    def train(self):\n",
    "        '''Train the model, saving it and printing samples after every epoch.\n",
    "\n",
    "        Every generator item is a single sample, so one epoch consumes\n",
    "        ``config.batch_size`` samples.\n",
    "        '''\n",
    "        number_of_epoch = 10\n",
    "        if self.model is None:\n",
    "            self.build_model()\n",
    "        self.model.summary()\n",
    "        self.model.fit_generator(\n",
    "            generator=self.data_generator(),\n",
    "            verbose=True,\n",
    "            steps_per_epoch=self.config.batch_size,\n",
    "            epochs=number_of_epoch,\n",
    "            callbacks=[\n",
    "                keras.callbacks.ModelCheckpoint(self.config.weight_file, save_weights_only=False),\n",
    "                LambdaCallback(on_epoch_end=self.generate_sample_result)\n",
    "            ]\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "_________________________________________________________________\n",
      "Layer (type)                 Output Shape              Param #   \n",
      "=================================================================\n",
      "input_2 (InputLayer)         (None, 6)                 0         \n",
      "_________________________________________________________________\n",
      "embedding_2 (Embedding)      (None, 6, 300)            1797300   \n",
      "_________________________________________________________________\n",
      "bidirectional_4 (Bidirection (None, 6, 256)            329472    \n",
      "_________________________________________________________________\n",
      "flatten_2 (Flatten)          (None, 1536)              0         \n",
      "_________________________________________________________________\n",
      "dense_2 (Dense)              (None, 5990)              9206630   \n",
      "=================================================================\n",
      "Total params: 11,333,402\n",
      "Trainable params: 11,333,402\n",
      "Non-trainable params: 0\n",
      "_________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "# Instantiating the model loads Config.weight_file when it exists and trains from scratch otherwise.\n",
    "model = PoetryModel(Config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 这里有个坑：mac 下 tensorflow 无法通过 pip 直接安装 1.2 之后的版本，一直报 max_itera** 错误，需要指定 tensorflow 1.4.0、keras 2.1.2 版本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "text:小雨\n",
      "小棨浓赓捶。雨匙腊狈爟。伍炅酿委沔。锼剑诗堂扞。\n"
     ]
    }
   ],
   "source": [
    "# Read the seed characters; predict() pads input shorter than four characters with random vocabulary characters.\n",
    "text = input(\"text:\")\n",
    "sentence = model.predict(text)\n",
    "print(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
